import jsonlines
from tqdm import tqdm
import re
import sys



# Name of the model whose outputs are evaluated; selects ./output/<model>.jsonl.
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} <model>")
model = sys.argv[1]

# Matches the literal tokens "Hypothesis1" or "Hypothesis2".
pattern = re.compile("Hypothesis[1-2]{1}")

# Read every record eagerly so the file handle can be closed right away.
with jsonlines.open(f"./output/{model}.jsonl", "r") as reader:
    data = list(reader)

# Disabled debugging aid: restrict evaluation to a subset of record positions
# (needs an `indexes` collection to be defined before it can be re-enabled).
# data = [d for i, d in enumerate(data) if i in indexes]

# results: per-rule counters, {rule: {"total": int, "correct": int}}.
results = {}
# predictions: one entry per record, 1 if the prediction was correct else 0.
predictions = []
# Maps the extracted prediction token to the value compared against d['label'].
mapping = {"Hypothesis1": 0, "Hypothesis2": 1, "None": 2}

# Score every record: extract the predicted hypothesis from the raw model
# output, then update the per-rule counters and the flat list of 0/1 outcomes.
for d in tqdm(data):
    if 'llama' not in model:
        # Keep the text after the first assistant marker, with spaces removed.
        # NOTE: raises IndexError if the marker is absent from the answer.
        prediction = d['answer'].split("<|im_start|>assistant")[1].replace(" ", "")
    else:
        try:
            prediction = d['answer'].split("\n\nAnswer:")[1].replace(" ", "")
        except (IndexError, AttributeError):
            # No "Answer:" marker (or a non-string answer): fall back to the
            # raw value and let the pattern search below decide.
            prediction = d['answer']

    # The first "Hypothesis1"/"Hypothesis2" occurrence wins; anything else
    # (no match, or a non-string value) is scored as "None".
    try:
        match = pattern.search(prediction)
    except TypeError:
        match = None
    prediction = match.group(0) if match else "None"

    rule = d['general_rule']
    if rule not in results:
        results[rule] = {"total": 0, "correct": 0}
    results[rule]["total"] += 1

    if mapping[prediction] == d['label']:
        results[rule]["correct"] += 1
        predictions.append(1)
    else:
        predictions.append(0)


# Keep only rules that were evaluated on at least two records.
filtered = {rule: stats for rule, stats in results.items() if stats["total"] >= 2}
# Number of records covered by the retained rules.
total = sum(stats["total"] for stats in filtered.values())

# Disabled alternative for `count`: also counts rules where every record was wrong.
# count = sum([filtered[r]["total"]==filtered[r]["correct"] or filtered[r]["correct"] == 0 for r in filtered])
# Hard: number of retained rules for which every record was predicted correctly.
count = sum(stats["correct"] == stats["total"] for stats in filtered.values())
# Soft: sum of the per-rule accuracies over the retained rules.
count_soft = sum(stats["correct"] / stats["total"] for stats in filtered.values())

# Guard the averages: with no retained rules (or no records at all) there is
# nothing to average over, so report NaN instead of raising ZeroDivisionError.
num_filtered = len(filtered)
num_records = len(data)
hard_acc = count / num_filtered if num_filtered else float("nan")
soft_acc = count_soft / num_filtered if num_filtered else float("nan")
overall_acc = sum(predictions) / num_records if num_records else float("nan")

print(f"Filtered Size: {total}")
print(f"Filtered Facts: {num_filtered}")
print(f"Original Facts: {len(results)}")
print(f"Abstract ACC Hard: {hard_acc}")
print(f"Abstract ACC Soft: {soft_acc}")
print(f"ACC: {overall_acc}")



